10. 训练网络

07 训练网络 V1

交叉熵损失

PyTorch 文档提到,交叉熵损失函数包括两步:

  • 首先对看到的任何输出应用 softmax 函数
  • 然后应用 NLLLoss 负对数似然损失

接着返回一批数据的平均损失。因为交叉熵损失会应用 softmax 函数,所以我们不需要在模型定义的 forward 函数中应用 softmax 函数;我们还可以采用另一种方式。

另一种方式

我们可以将 softmax 步骤和 NLLLoss 步骤分开处理。

  • 在模型的 forward 函数中,我们可以向输出 x 应用 softmax 激活函数。
 ...
 ...
# a softmax layer to convert 10 outputs into a distribution of class probabilities
x = F.log_softmax(x, dim=1)

return x
  • 然后,在定义损失函数时应用 NLLLoss
# cross entropy loss combines softmax and nn.NLLLoss() in one single class
# here, we've separated them
criterion = nn.NLLLoss()

这样会将常规的 criterion = nn.CrossEntropy() 分成两步:softmax 和NLLLoss;如果你希望模型输出是类别概率,而不是类别分数的话,可以采取这种方式。